package mysqlclient

import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"errors"
	"fmt"
	"gitee.com/lsy007/mysqlclient/param"
	"github.com/golang/protobuf/proto"
	"github.com/smallnest/rpcx/client"
	"github.com/smallnest/rpcx/protocol"
)

// Header identifies the calling service. buildHeader copies ProId and Server
// into every request header, and uses Region unless the client overrides it.
type Header struct {
	ProId  string // unique ID of the project the service belongs to
	Region string // region the service runs in; default request region
	Server string // service name
}

// Service describes how to reach the database middleware and whether request
// payloads are encrypted.
type Service struct {
	Addr            string // address of the database middleware service in the current region
	SqlScenes       string // default database scenario to operate on: region/global/order (clients may override it)
	RpcRegisterName string // rpc registration name of the middleware service; fixed to "MysqlHandle"
	Encrypt         string // request-data encryption switch; "on" enables it, disabled by default
	Secret          string // AES key for the middleware; only read when Encrypt is "on" (aes.NewCipher accepts 16, 24 or 32 bytes)
}

// JoinInfo describes one JOIN clause. buildDbInfo copies it field-by-field into
// param.JoinInfo.
// NOTE(review): Type presumably holds the join kind (e.g. LEFT/INNER) and On
// the join predicate — confirm against the middleware.
type JoinInfo struct {
	Type  string
	Table string
	Alias string
	On    string
}

// Condition is a single `Column Operator Value` predicate. getCondition converts
// it to param.Condition; Value's dynamic type is encoded by getValue (see there
// for the supported types), and an unsupported type makes the request fail.
type Condition struct {
	Column   string
	Operator string
	Value    interface{}
}

// ConnCondition is a group of conditions used for DBInfo.And / DBInfo.Or.
// ConnType is forwarded verbatim; Condition is converted with getCondition.
// NOTE(review): ConnType presumably says how the grouped conditions are joined
// (e.g. AND/OR) — confirm against the middleware.
type ConnCondition struct {
	ConnType  string
	Condition []*Condition
}

// ChildCondition is a predicate against another table. Column, Operator and
// Table are forwarded verbatim and Condition is converted with getCondition.
// NOTE(review): presumably rendered as a sub-query such as
// `Column Operator (SELECT ... FROM Table WHERE Condition...)` — confirm.
type ChildCondition struct {
	Column    string
	Operator  string
	Table     string
	Condition []*Condition
}

// ConnChildCondition is a group of child conditions used for DBInfo.AndChild /
// DBInfo.OrChild. ConnType is forwarded verbatim; Condition is converted with
// getChildCondition.
type ConnChildCondition struct {
	ConnType  string
	Condition []*ChildCondition
}

// DBInfo is the caller-facing description of one database operation.
// buildDbInfo converts it to param.DBInfo:
//   - Table, Alias, Func, Field, Cols, Omit, Id, AutoCondition, Order, Start,
//     Rows, Group, Having and Type are copied verbatim;
//   - Model, when non-nil, is protobuf-marshalled (and then AES-CTR encrypted
//     when Service.Encrypt is "on");
//   - Where/And/Or/AndChild/OrChild/Join are converted to their param types.
//
// NOTE(review): Start/Rows presumably map to OFFSET/LIMIT — confirm.
type DBInfo struct {
	Table         string
	Alias         string
	Func          string
	Field         []string
	Cols          []string
	Omit          []string
	Id            int64
	Where         []*Condition
	And           *ConnCondition
	Or            *ConnCondition
	AndChild      *ConnChildCondition
	OrChild       *ConnChildCondition
	AutoCondition int64
	Order         string
	Start         int64
	Rows          int64
	Group         string
	Having        string
	Model         proto.Message
	Type          int64
	Join          []*JoinInfo
	// NOTE(review): Master is never forwarded by buildDbInfo in this file, so
	// setting it currently has no effect — confirm whether param.DBInfo has a
	// matching field that should be copied.
	Master        int64
}

// Client issues requests that carry a single DBInfo. Create it with InitConfig,
// then set DBInfo (and optionally RequestId, SqlScenes, Region) before a call.
type Client struct {
	Header  *Header
	Service *Service
	// RequestId is sent in the request header. When Service.Encrypt is "on", its
	// first 16 bytes are used as the AES-CTR IV, so it must be at least 16 bytes.
	RequestId string
	// DBInfo is required for every request except SeedDatabaseRequest.
	DBInfo    *DBInfo
	SqlScenes string // custom database scenario; overrides Service.SqlScenes when non-empty
	Region    string // custom region; overrides Header.Region when non-empty
}

// SessionClient issues session requests that carry several DBInfo entries in
// one call. Create it with InitConfig.
type SessionClient struct {
	Header  *Header
	Service *Service
	// RequestId is sent in the request header. When Service.Encrypt is "on", its
	// first 16 bytes are used as the AES-CTR IV for every entry's Model.
	RequestId string
	DBInfo    []*DBInfo
	SqlScenes string // custom database scenario; overrides Service.SqlScenes when non-empty
	Region    string // custom region; overrides Header.Region when non-empty
}

// Response is a client-side result shape.
// NOTE(review): no function in this file uses it — the request methods return
// param.Response instead. Confirm whether external callers still need it.
type Response struct {
	RequestId string
	Region    string
	Has       bool
	Value     interface{}
	Model     []byte
}

// InitConfig validates the header and middleware settings and returns a Client
// (single-DBInfo requests) and a SessionClient (multi-DBInfo session requests).
// Both share the same h and s pointers.
//
// Requirements:
//   - h and s are non-nil;
//   - h.ProId, h.Region, h.Server, s.Addr and s.SqlScenes are non-empty;
//   - s.Secret is non-empty when s.Encrypt is "on".
//
// Encryption is optional and off by default. The secret is only read on the
// encrypted request path, so it is not demanded otherwise.
func InitConfig(h *Header, s *Service) (*Client, *SessionClient, error) {
	// 1. Validate the header information (a nil header used to panic here).
	if h == nil || h.ProId == "" || h.Region == "" || h.Server == "" {
		return nil, nil, errors.New("mysqlClient has a default Header configuration that is empty")
	}
	// 2. Validate the middleware connection information.
	if s == nil || s.Addr == "" || s.SqlScenes == "" {
		return nil, nil, errors.New("mysqlService link information is incomplete")
	}
	// An AES secret is only needed when encryption is switched on.
	if s.Encrypt == "on" && s.Secret == "" {
		return nil, nil, errors.New("mysqlService link information is incomplete")
	}
	return &Client{Header: h, Service: s}, &SessionClient{Header: h, Service: s}, nil
}

// SingleRequest performs a single-table operation described by c.DBInfo in the
// client's region. A business error reported by the middleware (ErrMsg) is
// returned as err alongside the decoded reply.
func (c *Client) SingleRequest() (reply param.Response, err error) {
	err = baseRequest(c, "SingleRequestAction", &reply)
	if err == nil && reply.ErrMsg != "" {
		err = errors.New(reply.ErrMsg)
	}
	return reply, err
}

// AllRegionSingleRequest runs c.DBInfo against every region. It returns the
// first region reply that carries data:
//   - a region reporting ErrMsg aborts with an error naming that region;
//   - a reply with Has set yields its Model, Type, RequestId and Region;
//   - otherwise a reply whose Value has a non-zero scalar or a non-empty slice
//     yields that Value with its RequestId and Region.
//
// When no region has data, a zero Response and nil error are returned.
func (c *Client) AllRegionSingleRequest() (reply param.Response, err error) {
	all := param.AllRegionResponse{}
	if err = baseRequest(c, "AllRegionSingleRequestAction", &all); err != nil {
		return reply, err
	}
	for _, r := range all.Reply {
		switch {
		case r.ErrMsg != "":
			return reply, fmt.Errorf("region %s action error：%s", r.Region, r.ErrMsg)
		case r.Has:
			reply.Has = true
			reply.Model = r.Model
			reply.Type = r.Type
			reply.RequestId = r.RequestId
			reply.Region = r.Region
			return reply, nil
		case r.Value != nil && (r.Value.Int != 0 || r.Value.Float != 0 || r.Value.String_ != "" ||
			len(r.Value.IntSlice) > 0 || len(r.Value.FloatSlice) > 0 || len(r.Value.StringSlice) > 0):
			reply.RequestId = r.RequestId
			reply.Region = r.Region
			reply.Value = r.Value
			return reply, nil
		}
	}
	return reply, nil
}

// JoinRequest performs a join query described by c.DBInfo (including its Join
// clauses) in the client's region. A business error reported by the middleware
// (ErrMsg) is returned as err alongside the decoded reply.
func (c *Client) JoinRequest() (reply param.Response, err error) {
	err = baseRequest(c, "JoinRequestAction", &reply)
	if err == nil && reply.ErrMsg != "" {
		err = errors.New(reply.ErrMsg)
	}
	return reply, err
}

// AllRegionJoinRequest runs the join described by c.DBInfo against every
// region. Like AllRegionSingleRequest, it returns the first region reply that
// carries data:
//   - a region reporting ErrMsg aborts with an error naming that region;
//   - a reply with Has set yields its Model, Type, RequestId and Region;
//   - otherwise a reply whose Value has a non-zero scalar or a non-empty slice
//     yields that Value with its RequestId and Region.
//
// When no region has data, a zero Response and nil error are returned.
func (c *Client) AllRegionJoinRequest() (reply param.Response, err error) {
	allReply := param.AllRegionResponse{}
	if err = baseRequest(c, "AllRegionJoinRequestAction", &allReply); err != nil {
		return reply, err
	}
	for _, v := range allReply.Reply {
		// (1) Abort on a business error from any region.
		if v.ErrMsg != "" {
			return reply, fmt.Errorf("region %s action error：%s", v.Region, v.ErrMsg)
		}
		// (2) Return as soon as a region has a model. Without this return, later
		// regions overwrote the result. They could also mix their
		// RequestId/Region with this region's Model, or turn a found result
		// into an error.
		if v.Has {
			reply.Has = true
			reply.Model = v.Model
			reply.Type = v.Type
			reply.RequestId = v.RequestId
			reply.Region = v.Region
			return reply, nil
		}
		// (3) Otherwise accept a reply whose Value carries anything non-zero.
		if v.Value != nil && (v.Value.Int != 0 || v.Value.Float != 0 || v.Value.String_ != "" ||
			len(v.Value.IntSlice) > 0 || len(v.Value.FloatSlice) > 0 || len(v.Value.StringSlice) > 0) {
			reply.RequestId = v.RequestId
			reply.Region = v.Region
			reply.Value = v.Value
			return reply, nil
		}
	}
	return reply, nil
}

// SeedDatabaseRequest sends "SeedDatabaseRequestAction". No DBInfo is built or
// sent for this action. A business error (ErrMsg) from the middleware is
// returned as err.
func (c *Client) SeedDatabaseRequest() (err error) {
	resp := param.Response{}
	err = baseRequest(c, "SeedDatabaseRequestAction", &resp)
	if err == nil && resp.ErrMsg != "" {
		err = errors.New(resp.ErrMsg)
	}
	return err
}

// DatabaseInfoRequest sends "DatabaseInfoRequestAction" with c.DBInfo. The Model
// in this request is never encrypted. It returns the raw Model bytes of the
// reply, or nil and an error on transport or business (ErrMsg) failure.
func (c *Client) DatabaseInfoRequest() (data []byte, err error) {
	resp := param.Response{}
	if e := baseRequest(c, "DatabaseInfoRequestAction", &resp); e != nil {
		return nil, e
	}
	if msg := resp.ErrMsg; msg != "" {
		return nil, errors.New(msg)
	}
	return resp.Model, nil
}

// SessionRequest sends all of s.DBInfo as one session request in the client's
// region. A business error (ErrMsg) from the middleware is returned as err
// alongside the decoded reply.
func (s *SessionClient) SessionRequest() (reply param.Response, err error) {
	err = sessRequest(s, "SessionRequestAction", &reply)
	if err == nil && reply.ErrMsg != "" {
		err = errors.New(reply.ErrMsg)
	}
	return reply, err
}

// AllRegionSessionRequest sends all of s.DBInfo as a session request to every
// region. It returns an error naming the first region, in reply order, that
// reported a business error; otherwise it returns nil.
func (s *SessionClient) AllRegionSessionRequest() (err error) {
	all := param.AllRegionResponse{}
	if err = sessRequest(s, "AllRegionSessionRequestAction", &all); err != nil {
		return err
	}
	for i := range all.Reply {
		if msg := all.Reply[i].ErrMsg; msg != "" {
			return fmt.Errorf("region %s action error：%s", all.Reply[i].Region, msg)
		}
	}
	return nil
}

// baseRequest builds and sends a Client request for the RPC method fun and
// decodes the result into reply. It:
//   - converts c.DBInfo to the wire type (skipped for
//     "SeedDatabaseRequestAction", which sends no DBInfo);
//   - builds the request header;
//   - when c.Service.Encrypt is "on", AES-CTR-encrypts the serialized Model.
//     The seed and database-info actions are exempt from encryption.
func baseRequest(c *Client, fun string, reply interface{}) (err error) {
	const (
		seedAction   = "SeedDatabaseRequestAction"
		dbInfoAction = "DatabaseInfoRequestAction"
	)
	// 0. Reject an unconfigured client.
	if c.Header == nil || c.Service == nil {
		return errors.New("mysqlClient config is empty")
	}
	// 1. Build the database info. A nil DBInfo used to panic in buildDbInfo.
	var dbInfo *param.DBInfo
	if fun != seedAction {
		if c.DBInfo == nil {
			return errors.New("mysqlClient DBInfo is empty")
		}
		if dbInfo, err = buildDbInfo(c.DBInfo); err != nil {
			return err
		}
	}
	// 2. Assemble the request header.
	header := buildHeader(c.RequestId, c.SqlScenes, c.Region, c.Header, c.Service)
	// 3. Assemble the arguments, encrypting the Model if required.
	// NOTE(review): buildHeader still sets header.Encrypt for the seed/database-
	// info actions although their payload is not encrypted here — confirm the
	// middleware ignores the flag for those actions.
	args := param.SingleRequest{Header: header, DBInfo: dbInfo}
	if fun != seedAction && fun != dbInfoAction && c.Service.Encrypt == "on" {
		// The IV is the first aes.BlockSize bytes of the request ID, so the ID
		// must be at least that long (exactly 16 is fine).
		// NOTE(review): reusing a RequestId with the same Secret reuses the CTR
		// keystream; RequestIds should be unique per request.
		if len(header.RequestId) < aes.BlockSize {
			return fmt.Errorf("the length of the requestId must be at least %d", aes.BlockSize)
		}
		iv := []byte(header.RequestId[:aes.BlockSize])
		if args.DBInfo.Model, err = aesCTREncrypt(args.DBInfo.Model, []byte(c.Service.Secret), iv); err != nil {
			return err
		}
	}
	// 4. Send the request.
	return Request(c.Service.Addr, &args, reply, c.Service.RpcRegisterName, fun)
}

// sessRequest builds and sends a SessionClient request for the RPC method fun
// and decodes the result into reply. It:
//   - converts every entry of s.DBInfo (in order);
//   - builds the request header;
//   - when s.Service.Encrypt is "on", AES-CTR-encrypts each serialized Model.
func sessRequest(s *SessionClient, fun string, reply interface{}) (err error) {
	// 0. Reject an unconfigured client.
	if s.Header == nil || s.Service == nil {
		return errors.New("mysqlClient config is empty")
	}
	// 1. Convert each DBInfo; a nil entry used to panic in buildDbInfo.
	dbInfo := make([]*param.DBInfo, 0, len(s.DBInfo))
	for _, v := range s.DBInfo {
		if v == nil {
			return errors.New("mysqlClient DBInfo is empty")
		}
		var db *param.DBInfo
		if db, err = buildDbInfo(v); err != nil {
			return err
		}
		dbInfo = append(dbInfo, db)
	}
	// 2. Assemble the request header.
	header := buildHeader(s.RequestId, s.SqlScenes, s.Region, s.Header, s.Service)
	// 3. Assemble the arguments, encrypting every Model if required.
	args := &param.SessionRequest{Header: header, DBInfo: dbInfo}
	if s.Service.Encrypt == "on" {
		// The IV is the first aes.BlockSize bytes of the request ID, so the ID
		// must be at least that long (exactly 16 is fine).
		if len(header.RequestId) < aes.BlockSize {
			return fmt.Errorf("the length of the requestId must be at least %d", aes.BlockSize)
		}
		iv := []byte(header.RequestId[:aes.BlockSize])
		secret := []byte(s.Service.Secret)
		// NOTE(review): every Model in the session is encrypted with the same
		// key and IV, which reuses the CTR keystream across them. Fixing this
		// needs a protocol change agreed with the middleware.
		for _, db := range args.DBInfo {
			if db.Model, err = aesCTREncrypt(db.Model, secret, iv); err != nil {
				return err
			}
		}
	}
	// 4. Send the request.
	return Request(s.Service.Addr, args, reply, s.Service.RpcRegisterName, fun)
}

// buildHeader assembles the wire request header. It starts from h's project,
// server and region and from s's database scenario. A non-empty region or
// sqlScenes argument overrides the configured value. Encrypt is set to 1 when
// s.Encrypt is "on".
func buildHeader(requestId, sqlScenes, region string, h *Header, s *Service) *param.Header {
	hdr := &param.Header{
		RequestId: requestId,
		ProId:     h.ProId,
		Server:    h.Server,
		Region:    h.Region,
		SqlScenes: s.SqlScenes,
	}
	// Per-client overrides win over the shared configuration.
	if region != "" {
		hdr.Region = region
	}
	if sqlScenes != "" {
		hdr.SqlScenes = sqlScenes
	}
	if s.Encrypt == "on" {
		hdr.Encrypt = 1
	}
	return hdr
}

// buildDbInfo converts a caller-facing DBInfo into the param.DBInfo wire type.
//   - Scalar fields are copied as-is.
//   - Model, when non-nil, is protobuf-marshalled.
//   - Where/And/Or/AndChild/OrChild/Join are converted to their param types.
//
// A nil db returns an error instead of panicking. On any error the returned
// dbInfo is nil.
func buildDbInfo(db *DBInfo) (dbInfo *param.DBInfo, err error) {
	if db == nil {
		return nil, errors.New("mysqlClient DBInfo is empty")
	}
	// 0. Copy the scalar fields and protobuf-encode the Model.
	// NOTE(review): db.Master is not forwarded; if param.DBInfo has a matching
	// field it probably should be copied here — confirm.
	dbInfo = &param.DBInfo{Table: db.Table, Func: db.Func, Type: db.Type, Alias: db.Alias, Field: db.Field, Cols: db.Cols, Omit: db.Omit, Id: db.Id,
		AutoCondition: db.AutoCondition, Order: db.Order, Start: db.Start, Rows: db.Rows, Group: db.Group, Having: db.Having}
	if db.Model != nil {
		if dbInfo.Model, err = proto.Marshal(db.Model); err != nil {
			return nil, err
		}
	}
	// 1. Where conditions.
	if len(db.Where) != 0 {
		if dbInfo.Where, err = getCondition(db.Where); err != nil {
			return nil, err
		}
	}
	// 2-3. And / Or condition groups (nil stays nil).
	if dbInfo.And, err = buildConnCondition(db.And); err != nil {
		return nil, err
	}
	if dbInfo.Or, err = buildConnCondition(db.Or); err != nil {
		return nil, err
	}
	// 4-5. AndChild / OrChild condition groups (nil stays nil).
	if dbInfo.AndChild, err = buildConnChildCondition(db.AndChild); err != nil {
		return nil, err
	}
	if dbInfo.OrChild, err = buildConnChildCondition(db.OrChild); err != nil {
		return nil, err
	}
	// 6. Join clauses.
	if len(db.Join) != 0 {
		join := make([]*param.JoinInfo, 0, len(db.Join))
		for _, v := range db.Join {
			join = append(join, &param.JoinInfo{Type: v.Type, Table: v.Table, Alias: v.Alias, On: v.On})
		}
		dbInfo.Join = join
	}
	return dbInfo, nil
}

// buildConnCondition converts an optional ConnCondition; nil in, nil out.
func buildConnCondition(cc *ConnCondition) (*param.ConnCondition, error) {
	if cc == nil {
		return nil, nil
	}
	conds, err := getCondition(cc.Condition)
	if err != nil {
		return nil, err
	}
	return &param.ConnCondition{ConnType: cc.ConnType, Condition: conds}, nil
}

// buildConnChildCondition converts an optional ConnChildCondition; nil in, nil out.
func buildConnChildCondition(cc *ConnChildCondition) (*param.ConnChildCondition, error) {
	if cc == nil {
		return nil, nil
	}
	conds, err := getChildCondition(cc.Condition)
	if err != nil {
		return nil, err
	}
	return &param.ConnChildCondition{ConnType: cc.ConnType, Condition: conds}, nil
}

// getChildCondition converts child conditions to their wire form, in order.
// The result is always non-nil (empty for empty input). On the first
// conversion error it returns the entries converted so far, plus the error.
func getChildCondition(cond []*ChildCondition) (condSlice []*param.ChildCondition, err error) {
	condSlice = make([]*param.ChildCondition, 0, len(cond))
	for i := range cond {
		src := cond[i]
		var inner []*param.Condition
		if inner, err = getCondition(src.Condition); err != nil {
			return condSlice, err
		}
		condSlice = append(condSlice, &param.ChildCondition{
			Column:    src.Column,
			Operator:  src.Operator,
			Table:     src.Table,
			Condition: inner,
		})
	}
	return condSlice, nil
}

// getCondition converts conditions to their wire form, in order, encoding each
// Value via getValue. The result is always non-nil (empty for empty input). On
// the first unsupported value it returns the entries converted so far, plus
// the error.
func getCondition(cond []*Condition) (condSlice []*param.Condition, err error) {
	condSlice = make([]*param.Condition, 0, len(cond))
	for i := range cond {
		src := cond[i]
		dst := &param.Condition{Column: src.Column, Operator: src.Operator}
		if err = getValue(dst, src.Value); err != nil {
			return condSlice, err
		}
		condSlice = append(condSlice, dst)
	}
	return condSlice, nil
}

// getValue encodes a Go condition value into c. It sets c.VariableType to
// "int", "float" or "string" and fills the matching scalar field (Int, Float,
// String_) or slice field (IntSlice, FloatSlice, StringSlice).
//
// Supported scalars:
//   - int, int8, int16, int32, int64;
//   - uint8, uint16, uint32 (always fit in int64);
//   - float32, float64;
//   - string.
//
// Supported slices: []int, []int32, []int64, []float64, []string.
// uint/uint64 are rejected because they may not fit in int64. Any other type
// (including nil) returns an error.
func getValue(c *param.Condition, v interface{}) error {
	switch x := v.(type) {
	case int:
		c.VariableType = "int"
		c.Int = int64(x)
	case int8:
		c.VariableType = "int"
		c.Int = int64(x)
	case int16:
		c.VariableType = "int"
		c.Int = int64(x)
	case int32:
		c.VariableType = "int"
		c.Int = int64(x)
	case int64:
		c.VariableType = "int"
		c.Int = x
	case uint8:
		c.VariableType = "int"
		c.Int = int64(x)
	case uint16:
		c.VariableType = "int"
		c.Int = int64(x)
	case uint32:
		c.VariableType = "int"
		c.Int = int64(x)
	case []int:
		ints := make([]int64, 0, len(x))
		for _, i := range x {
			ints = append(ints, int64(i))
		}
		c.VariableType = "int"
		c.IntSlice = ints
	case []int32:
		ints := make([]int64, 0, len(x))
		for _, i := range x {
			ints = append(ints, int64(i))
		}
		c.VariableType = "int"
		c.IntSlice = ints
	case []int64:
		c.VariableType = "int"
		c.IntSlice = x
	case float32:
		c.VariableType = "float"
		c.Float = float64(x)
	case float64:
		c.VariableType = "float"
		c.Float = x
	case []float64:
		c.VariableType = "float"
		c.FloatSlice = x
	case string:
		c.VariableType = "string"
		c.String_ = x
	case []string:
		c.VariableType = "string"
		c.StringSlice = x
	default:
		return errors.New("条件值错误")
	}
	return nil
}

// Request makes one rpcx call to the service registered as registerName at
// addr. It invokes method fun with args and decodes the result into reply,
// using protobuf encoding. The client is created for this call and closed when
// it returns.
// NOTE(review): a new client is built per call, and context.Background() sets
// no deadline. Consider reusing clients and bounding calls with a timeout.
func Request(addr string, args interface{}, reply interface{}, registerName, fun string) (err error) {
	discovery := client.NewPeer2PeerDiscovery(addr, "")
	opt := client.DefaultOption
	// The middleware speaks protobuf.
	opt.SerializeType = protocol.ProtoBuffer
	xc := client.NewXClient(registerName, client.Failtry, client.RandomSelect, discovery, opt)
	defer xc.Close()
	return xc.Call(context.Background(), fun, args, reply)
}

// 基于aes加密算法的ctr分组模式的加密
// key加解密秘钥，长度32字节；iv用requestId截取长度16字节的值
// aesCTREncrypt encrypts plainText with AES in CTR mode and returns a new slice
// of the same length. CTR is symmetric: applying it again to the ciphertext
// with the same key and iv recovers the plaintext.
//
// key is the shared AES secret. It is documented as 32 bytes (AES-256);
// 16- and 24-byte keys are also accepted. iv must be exactly aes.BlockSize
// (16) bytes; callers take it from the first 16 bytes of the request ID.
// An invalid key or iv length is returned as an error instead of letting
// cipher.NewCTR panic.
func aesCTREncrypt(plainText, key, iv []byte) ([]byte, error) {
	// Create the AES block cipher (rejects invalid key lengths).
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	// cipher.NewCTR panics on a wrong-length IV, so validate it first.
	if len(iv) != block.BlockSize() {
		return nil, fmt.Errorf("aes ctr: invalid iv length %d, want %d", len(iv), block.BlockSize())
	}
	cipherText := make([]byte, len(plainText))
	cipher.NewCTR(block, iv).XORKeyStream(cipherText, plainText)
	return cipherText, nil
}
